{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial - Using your own Dataset\n",
    "\n",
    "In this notebook, we will see how to use your own `Dataset` within `pythae`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library\n",
    "%pip install pythae"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.datasets as datasets\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n",
    "\n",
    "train_dataset = mnist_trainset.data[:10000].reshape(-1, 1, 28, 28) / 255.\n",
    "train_targets = mnist_trainset.targets[:10000]\n",
    "eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255.\n",
    "eval_targets = mnist_trainset.targets[-10000:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example\n",
    "Below is an example where we build a dataset inheriting from [`torchvision.datasets.ImageFolder`](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder) that iterates over images located in a folder. In particular, this loader avoids loading all the data at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will save the data in folders to mimic the desired example\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import imageio\n",
    "\n",
    "def save_split_as_image_folders(images, targets, split_dir):\n",
    "    \"\"\"Save every image as a JPEG under `split_dir/<label>/`, one sub-folder per label.\n",
    "\n",
    "    Args:\n",
    "        images: tensor where `images[i][0]` is a 2D image (values expected in [0, 1]).\n",
    "        targets: labels aligned with `images`; each label names the image's sub-folder.\n",
    "        split_dir: directory receiving the label sub-folders (created if missing).\n",
    "    \"\"\"\n",
    "    # exist_ok=True makes the cell safe to re-run and also creates missing parent folders\n",
    "    os.makedirs(split_dir, exist_ok=True)\n",
    "    for i in range(len(images)):\n",
    "        # Rescale to 0-255 and round before the uint8 cast, which would otherwise truncate\n",
    "        img = (255.0 * images[i][0]).round().type(torch.uint8)\n",
    "        # Replicate the single channel 3 times to get an (H, W, 3) array for imageio\n",
    "        img = img.unsqueeze(-1).repeat(1, 1, 3).numpy()\n",
    "        img_folder = os.path.join(split_dir, f\"{targets[i]}\")\n",
    "        os.makedirs(img_folder, exist_ok=True)\n",
    "        imageio.imwrite(os.path.join(img_folder, \"%08d.jpg\" % i), img)\n",
    "\n",
    "\n",
    "# NOTE: `train_dataset` / `eval_dataset` must still be the raw tensors built in the loading cell;\n",
    "# a later cell rebinds both names to `MyCustomDataset` objects, so re-run the loading cell first if needed.\n",
    "save_split_as_image_folders(train_dataset, train_targets, os.path.join(\"data_folders\", \"train\"))\n",
    "save_split_as_image_folders(eval_dataset, eval_targets, os.path.join(\"data_folders\", \"eval\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define transforms to be applied on the data when reloaded\n",
    "from torchvision import datasets, transforms\n",
    "data_transform = transforms.Compose([\n",
    "    transforms.Grayscale(num_output_channels=1),\n",
    "    transforms.ToTensor() # the data must be tensors\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define your `CustomDataset`\n",
    "\n",
    "In this example, we build a custom dataset inheriting from `ImageFolder`. The only thing to keep in mind when building a custom dataset for use in `pythae` is that the `__getitem__` method must return a `DatasetOutput` instance containing at least a `data` key. Otherwise, you will not be able to combine your dataset with the `pipelines`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.data.datasets import DatasetOutput\n",
    "\n",
    "class MyCustomDataset(datasets.ImageFolder):\n",
    "    \"\"\"`ImageFolder` dataset whose items are wrapped in a pythae `DatasetOutput`.\"\"\"\n",
    "\n",
    "    def __init__(self, root, transform=None, target_transform=None):\n",
    "        \"\"\"Forward `root` and the optional transforms unchanged to `ImageFolder`.\"\"\"\n",
    "        super().__init__(\n",
    "            root=root,\n",
    "            transform=transform,\n",
    "            target_transform=target_transform,\n",
    "        )\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the item at `index` as a `DatasetOutput` holding only `data`.\"\"\"\n",
    "        # The parent returns a pair; its second element is discarded here\n",
    "        image, _unused = super().__getitem__(index)\n",
    "        return DatasetOutput(data=image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = MyCustomDataset(\n",
    "    root=\"data_folders/train\",\n",
    "    transform=data_transform,\n",
    ")\n",
    "\n",
    "eval_dataset = MyCustomDataset(\n",
    "    root=\"data_folders/eval\", \n",
    "    transform=data_transform\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use your CustomDataset to train a `pythae` model\n",
    "\n",
    "Now, the datasets can be passed to the `TrainingPipeline` to train any model implemented in `pythae`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import VAE, VAEConfig\n",
    "from pythae.trainers import BaseTrainerConfig\n",
    "from pythae.pipelines.training import TrainingPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = BaseTrainerConfig(\n",
    "    output_dir='my_model',\n",
    "    learning_rate=1e-3,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    num_epochs=10, # Change this to train the model a bit more\n",
    ")\n",
    "\n",
    "\n",
    "model_config = VAEConfig(\n",
    "    input_dim=(1, 28, 28),\n",
    "    latent_dim=16\n",
    ")\n",
    "\n",
    "model = VAE(\n",
    "    model_config=model_config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = TrainingPipeline(\n",
    "    training_config=config,\n",
    "    model=model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline(\n",
    "    train_data=train_dataset, # here we use the custom train dataset\n",
    "    eval_data=eval_dataset # here we use the custom eval dataset\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's have a look at the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pythae.models import AutoModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "last_training = sorted(os.listdir('my_model'))[-1]\n",
    "trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.samplers import NormalSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create normal sampler\n",
    "normal_samper = NormalSampler(\n",
    "    model=trained_model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample\n",
    "gen_data = normal_samper.sample(\n",
    "    num_samples=25\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# display the generated samples on a 5x5 grid\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "# `axes.flat` walks the grid row by row, so the k-th sample lands at row k // 5, column k % 5\n",
    "for ax, sample in zip(axes.flat, gen_data):\n",
    "    ax.imshow(sample.cpu().squeeze(0), cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "5e51c5ac46389dd7ba2bd8215d251ab84152720d3cad2ff91113d77594821aef"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('pythae')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
